#pragma once
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "ring_buffer.h"

// Declares the copy constructor and the copy assignment operator of TypeName
// as deleted. Intended to be placed inside the class body of TypeName; the
// expansion already ends with a ';'.
#define DISABLE_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName &) = delete;    \
  TypeName &operator=(const TypeName &) = delete;

// Default value of the `capacity` argument of DataDumper::Init (2,000,000).
constexpr uint32_t kDefaultRingBuffer = 2 * 1000 * 1000;
// Maximum batch length: 5 * 2^20 (5 MB) — presumably counted in bytes; the
// code that consumes it is not part of this header.
constexpr uint32_t kBatchMaxLen = 5U << 20;
// Wait limit of 100,000 — the "Us" suffix suggests microseconds (i.e. 100 ms);
// TODO confirm against the code that uses it.
constexpr uint32_t kMaxWaitTimeUs = 100U * 1000U;
// Limit of 10 on a wait counter — presumably the number of waits/retries;
// usage is not visible here.
constexpr uint32_t kMaxWaitTimes = 10U;

// Numeric identifiers associated with OpRangeData. Several enumerators share
// their names with OpRangeData members (is_async, name, stack, ...);
// presumably they tag the corresponding sections of the encoded record —
// the encoder is not visible in this header, so confirm there.
// The underlying type is int (the default for a scoped enum, spelled out here).
enum class OpRangeDataType : int {
  OP_RANGE_DATA = 1,
  IS_ASYNC = 2,
  NAME = 3,
  INPUT_DTYPES = 4,
  INPUT_SHAPE = 5,
  STACK = 6,
  MODULE_HIERARCHY = 7,
  EXTRA_ARGS = 8,
  CUSTOM_INFO = 9,
  // Values 10..29 are currently unassigned.
  RESERVED = 30,
};

// Abstract base of the records accepted by DataDumper::Report.
// Derived types provide encode() and preprocess().
struct BaseReportData {
  int32_t device_id{0};  // device id supplied by whoever creates the record
  std::string tag;       // label supplied at construction (moved in)

  BaseReportData(int32_t device_id, std::string tag) : device_id(device_id), tag(std::move(tag)) {}
  // Virtual so records can be destroyed through a BaseReportData pointer.
  virtual ~BaseReportData() = default;

  // Returns the record as a byte vector; the layout is defined by the derived type.
  virtual std::vector<uint8_t> encode() = 0;
  // Hook implemented by the derived type; presumably run before encode() —
  // the call site is not visible in this header.
  virtual void preprocess() = 0;
};

// Report record describing one operator range.
// Its tag is always "op_range_<device_id>".
// NOTE(review): the "_ns" suffixes suggest nanosecond timestamps — confirm with the producer.
struct OpRangeData : BaseReportData {
  int64_t start_ns{0};
  int64_t end_ns{0};
  int64_t sequence_number{0};
  uint64_t process_id{0};
  uint64_t start_thread_id{0};
  uint64_t end_thread_id{0};
  uint64_t forward_thread_id{0};
  bool is_async{false};
  std::string name;
  // Neither constructor fills the next two members or module_hierarchy; they
  // start empty and must be set by the caller afterwards if needed.
  std::vector<std::string> input_dtypes;
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::string> stack;
  std::vector<std::string> module_hierarchy;
  uint64_t flow_id{0};
  int8_t level{-1};
  std::map<std::string, std::string> custom_info{};
  uint64_t step{0};

  // Full constructor: `name` and `stack` are moved into the record,
  // `custom_info` is copied.
  OpRangeData(int64_t start_ns, int64_t end_ns, int64_t sequence_number, uint64_t process_id, uint64_t start_thread_id,
              uint64_t end_thread_id, uint64_t forward_thread_id, bool is_async, std::string name,
              std::vector<std::string> stack, uint64_t flow_id, int32_t device_id, uint64_t step, int8_t level,
              const std::map<std::string, std::string> &custom_info)
      : BaseReportData(device_id, "op_range_" + std::to_string(device_id)),
        start_ns(start_ns),
        end_ns(end_ns),
        sequence_number(sequence_number),
        process_id(process_id),
        start_thread_id(start_thread_id),
        end_thread_id(end_thread_id),
        forward_thread_id(forward_thread_id),
        is_async(is_async),
        name(std::move(name)),
        stack(std::move(stack)),
        flow_id(flow_id),
        level(level),
        custom_info(custom_info),
        step(step) {}

  // Minimal constructor: every member not listed here keeps its in-class
  // default (ids and counters 0, is_async false, level -1, containers empty).
  OpRangeData(int64_t start_ns, int64_t end_ns, uint64_t start_thread_id, std::string name, int32_t device_id)
      : BaseReportData(device_id, "op_range_" + std::to_string(device_id)),
        start_ns(start_ns),
        end_ns(end_ns),
        start_thread_id(start_thread_id),
        name(std::move(name)) {}

  // Implementations of the BaseReportData interface (defined out of line).
  // `override` makes the compiler reject these if the base signatures change.
  std::vector<uint8_t> encode() override;
  void preprocess() override;
};


// Accepts BaseReportData records through Report() and holds the state needed
// to write them out: an output path, a rank id, a ring buffer of pending
// records and a map of FILE handles.
// Instances can only be obtained through GetInstance(): the constructor and
// destructor are private and copying is disabled.
// The member functions are defined outside this header; the notes below are
// limited to what the declarations show.
class DataDumper {
 public:
  // `capacity` defaults to kDefaultRingBuffer; presumably it sizes
  // data_chunk_buf_ — confirm in the definition.
  void Init(const std::string &path, int32_t rank_id, size_t capacity = kDefaultRingBuffer);
  void UnInit();
  // Takes ownership of `data` (unique_ptr passed by value).
  void Report(std::unique_ptr<BaseReportData> data);
  void Start();
  void Stop();
  void Flush();

  // Returns a reference to a DataDumper; presumably the single process-wide
  // instance, since no other way to create one is exposed.
  static DataDumper &GetInstance();

 private:
  // Receives byte buffers keyed by string — presumably keyed by record tag,
  // matching fd_map_; TODO confirm.
  void Dump(const std::map<std::string, std::vector<uint8_t>> &dataMap);
  void Run();
  void GatherAndDumpData();

 private:
  DataDumper();
  virtual ~DataDumper();

  std::string path_;
  int32_t rank_id_{0};
  // All three flags start out false. Before C++20 a default-constructed
  // std::atomic<bool> holds an indeterminate value, so each one is given an
  // explicit initializer; a constructor mem-initializer still takes precedence.
  std::atomic<bool> start_{false};
  std::atomic<bool> init_{false};
  std::atomic<bool> is_flush_{false};
  RingBuffer<std::unique_ptr<BaseReportData>> data_chunk_buf_;
  std::map<std::string, FILE *> fd_map_;  // raw FILE handles keyed by string
  std::mutex flush_mutex_;

  DISABLE_COPY_AND_ASSIGN(DataDumper);
};